/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "op_def_builder.h"
namespace ge {
namespace {
// Returns true when `index` addresses an existing element of a container that holds `size`
// elements. Negativity is tested first so the cast to size_t below can never wrap around.
bool IsIndexInRange(int index, size_t size) {
  return index >= 0 && static_cast<size_t>(index) < size;
}
}  // namespace

// Creates a builder that interns its strings into `pool_builder`.
// A null pool is recorded as FAILED; the setters still return *this so a call chain can
// continue, and Build() later reports the failure. Nothing in this file resets error_code_,
// so once a failure is recorded it stays recorded.
OpDefBuilder::OpDefBuilder(GraphPoolBuilder *pool_builder) : pool_builder_(pool_builder) {
  if (pool_builder_ == nullptr) {
    error_code_ = FAILED;
  }
}

// Sets the operator type name. The string is interned via pool_builder_->AddString and the
// returned value is stored in type_, to be resolved again in Build().
// A null `type` (or a null pool) records FAILED and leaves type_ unchanged.
OpDefBuilder &OpDefBuilder::SetType(const char *type) {
  if (type == nullptr || pool_builder_ == nullptr) {
    error_code_ = FAILED;
    return *this;
  }
  type_ = pool_builder_->AddString(type);
  return *this;
}

// Stores the infer-shape callback that Build() copies into OpDef::infer_shape_func.
// Records FAILED if the builder has no pool.
OpDefBuilder &OpDefBuilder::SetInferShapeFunc(InferShapeFunc func) {
  if (pool_builder_ == nullptr) {
    error_code_ = FAILED;
    return *this;
  }
  infer_shape_func_ = func;
  return *this;
}

// Stores the tiling callback that Build() copies into OpDef::tiling_func.
// Records FAILED if the builder has no pool.
OpDefBuilder &OpDefBuilder::SetTilingFunc(TilingFunc func) {
  if (pool_builder_ == nullptr) {
    error_code_ = FAILED;
    return *this;
  }
  tiling_func_ = func;
  return *this;
}

// Sets the number of input slots. Must be called before SetInputDef(), which only accepts
// indices in [0, num). Shrinking discards the definitions stored in the removed slots.
OpDefBuilder &OpDefBuilder::SetInputDefCount(size_t num) {
  inputs_def_.resize(num);
  return *this;
}

// Fills input slot `index`: copies the element type and interns the name.
// An out-of-range index (or a null pool) records FAILED and leaves the slots untouched.
OpDefBuilder &OpDefBuilder::SetInputDef(int index, const IoDef &def) {
  if (pool_builder_ == nullptr || !IsIndexInRange(index, inputs_def_.size())) {
    error_code_ = FAILED;
    return *this;
  }

  inputs_def_[index].e_type = def.e_type;
  // NOTE(review): def.name is passed to AddString unchecked, while SetType rejects a null
  // string — confirm whether IoDef::name may be null.
  inputs_def_[index].name = pool_builder_->AddString(def.name);

  return *this;
}

// Sets the number of output slots. Must be called before SetOutputDef(), which only accepts
// indices in [0, num). Shrinking discards the definitions stored in the removed slots.
OpDefBuilder &OpDefBuilder::SetOutputDefCount(size_t num) {
  outputs_def_.resize(num);
  return *this;
}

// Fills output slot `index`: copies the element type and interns the name.
// An out-of-range index (or a null pool) records FAILED and leaves the slots untouched.
OpDefBuilder &OpDefBuilder::SetOutputDef(int index, const IoDef &def) {
  if (pool_builder_ == nullptr || !IsIndexInRange(index, outputs_def_.size())) {
    error_code_ = FAILED;
    return *this;
  }

  outputs_def_[index].e_type = def.e_type;
  outputs_def_[index].name = pool_builder_->AddString(def.name);

  return *this;
}

// Sets the number of attribute slots. Must be called before SetAttrDef(), which only accepts
// indices in [0, num). Shrinking discards the definitions stored in the removed slots.
OpDefBuilder &OpDefBuilder::SetAttrDefCount(size_t num) {
  attrs_def_.resize(num);
  return *this;
}

// Fills attribute slot `index`: copies the element type and the default-value creator,
// and interns the name.
// An out-of-range index (or a null pool) records FAILED and leaves the slots untouched.
OpDefBuilder &OpDefBuilder::SetAttrDef(int index, const AttrDef &def) {
  if (pool_builder_ == nullptr || !IsIndexInRange(index, attrs_def_.size())) {
    error_code_ = FAILED;
    return *this;
  }

  attrs_def_[index].e_type = def.e_type;
  attrs_def_[index].name = pool_builder_->AddString(def.name);
  attrs_def_[index].creator = def.default_creator;

  return *this;
}

// Materializes the accumulated definition into `op_def`, resolving every interned string
// through `pool_reader`.
// Returns the first error recorded by a setter (op_def untouched), FAILED if either pointer
// argument is null, otherwise SUCCESS. The inputs/outputs/attrs vectors of `op_def` are
// resized to match this builder's slot counts.
Status OpDefBuilder::Build(const GraphPoolReader *pool_reader, OpDef *op_def) const {
  if (error_code_ != SUCCESS) {
    return error_code_;
  }
  // Both pointers are dereferenced below; reject null rather than crash.
  if (pool_reader == nullptr || op_def == nullptr) {
    return FAILED;
  }

  op_def->type = pool_reader->GetString(type_);
  op_def->infer_shape_func = infer_shape_func_;
  op_def->tiling_func = tiling_func_;

  op_def->inputs_def.resize(inputs_def_.size());
  for (size_t i = 0; i < inputs_def_.size(); ++i) {
    op_def->inputs_def[i].e_type = inputs_def_[i].e_type;
    op_def->inputs_def[i].name = pool_reader->GetString(inputs_def_[i].name);
  }

  op_def->outputs_def.resize(outputs_def_.size());
  for (size_t i = 0; i < outputs_def_.size(); ++i) {
    op_def->outputs_def[i].e_type = outputs_def_[i].e_type;
    op_def->outputs_def[i].name = pool_reader->GetString(outputs_def_[i].name);
  }

  op_def->attrs_def.resize(attrs_def_.size());
  for (size_t i = 0; i < attrs_def_.size(); ++i) {
    op_def->attrs_def[i].e_type = attrs_def_[i].e_type;
    op_def->attrs_def[i].name = pool_reader->GetString(attrs_def_[i].name);
    op_def->attrs_def[i].default_creator = attrs_def_[i].creator;
  }

  return SUCCESS;
}

}  // namespace ge